{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "75f9200e-7830-4ee5-8637-e67b5df57eac",
   "metadata": {},
   "source": [
    "# PyTorch Inference Optimizations with Intel® Advanced Matrix Extensions (Intel® AMX) Bfloat16 Integer8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48eb565f-ef03-40cb-9182-5b2b752331e8",
   "metadata": {},
   "source": [
    "The `PyTorch* Inference Optimizations with Advanced Matrix Extensions Bfloat16 Integer8` sample demonstrates how to run inference with the ResNet50 and BERT models using the Intel® Extension for PyTorch (IPEX).\n",
    "\n",
    "The Intel® Extension for PyTorch (IPEX) extends PyTorch* with optimizations for an extra performance boost on Intel® hardware. While most of these optimizations will be included in future PyTorch* releases, the extension delivers up-to-date features and optimizations for PyTorch on Intel® hardware. For example, newer optimizations include AVX-512 Vector Neural Network Instructions (AVX512 VNNI) and Intel® Advanced Matrix Extensions (Intel® AMX).\n",
    "\n",
    "| Area                  | Description\n",
    "|:---                   |:---\n",
    "| What you will learn   | Inference performance improvements using Intel® Extension for PyTorch (IPEX) with Intel® AMX BF16/INT8\n",
    "| Time to complete      | 5 minutes\n",
    "| Category              | Code Optimization\n",
    "\n",
    "## Purpose\n",
    "\n",
    "The Intel® Extension for PyTorch (IPEX) allows you to speed up inference on Intel® Xeon Scalable processors with lower-precision data formats and specialized instructions. The bfloat16 (BF16) data format uses half the bit width of 32-bit floating point (FP32), which reduces both the memory needed and the execution time. Likewise, the 8-bit integer (INT8) data format uses half the bit width of BF16. You should see better performance with the Intel® AMX instruction set than with Intel® Vector Neural Network Instructions (Intel® VNNI).\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "| Optimized for           | Description\n",
    "|:---                     |:---\n",
    "| OS                      | Ubuntu* 18.04 or newer\n",
    "| Hardware                | 4th Gen Intel® Xeon® Scalable Processors or newer\n",
    "| Software                | Intel® Extension for PyTorch (IPEX)\n",
    "\n",
    "## Key Implementation Details\n",
    "\n",
    "This code sample performs inference on the ResNet50 and BERT models using Intel® Extension for PyTorch (IPEX). For each pretrained model, a few warm-up iterations (5 in this notebook) run before a single batch is timed. Intel® Advanced Matrix Extensions (Intel® AMX) supports the BF16 and INT8 data types starting with the 4th Generation of Intel® Xeon® Scalable Processors. The inference times are then compared to show the speedup over FP32 when using BF16 and INT8 with AVX-512 VNNI and Intel® AMX. The following run cases are executed:\n",
    "\n",
    "1. FP32 (baseline)\n",
    "2. BF16 using AVX512_CORE_AMX\n",
    "3. INT8 using AVX512_CORE_VNNI\n",
    "4. INT8 using AVX512_CORE_AMX\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c254afc",
   "metadata": {},
   "source": [
    "## Installation of required packages\n",
    "\n",
    "Ensure the kernel is set to Pytorch-CPU before running the following code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa457cee-5b1e-4ec9-b03a-2a7b2a8b464e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install matplotlib transformers py-cpuinfo sentencepiece sacremoses "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e41ce52-c94c-4bdf-a528-0e0200fd5501",
   "metadata": {},
   "source": [
    "## Imports, Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e4eedf0-5c7c-49d3-be15-f46b4988d9ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from time import time\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import intel_extension_for_pytorch as ipex\n",
    "from intel_extension_for_pytorch.quantization import prepare, convert\n",
    "import torchvision\n",
    "from torchvision import models\n",
    "from transformers import BertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17246f67-0059-4b5f-afe8-a105d767b139",
   "metadata": {},
   "outputs": [],
   "source": [
    "SUPPORTED_MODELS = [\"resnet50\", \"bert\"]   # models supported by this code sample\n",
    "\n",
    "# ResNet sample data parameters\n",
    "RESNET_BATCH_SIZE = 64\n",
    "\n",
    "# BERT sample data parameters\n",
    "BERT_BATCH_SIZE = 64\n",
    "BERT_SEQ_LENGTH = 512"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9771f165",
   "metadata": {},
   "source": [
    "## Identify Supported ISA  \n",
    "We identify the underlying supported ISA to determine whether Intel® AMX is supported. The 4th Gen Intel® Xeon® Scalable Processor (codenamed Sapphire Rapids) or newer must be used to run this sample.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25c339a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if hardware supports Intel® AMX\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "from cpuinfo import get_cpu_info\n",
    "info = get_cpu_info()\n",
    "flags = info['flags']\n",
    "amx_supported = False\n",
    "for flag in flags:\n",
    "    if \"amx\" in flag:\n",
    "        amx_supported = True\n",
    "        break\n",
    "if not amx_supported:\n",
    "    print(\"Intel® AMX is not supported on current hardware. Code sample cannot be run.\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b3f461d",
   "metadata": {},
   "source": [
    "If the message \"Intel® AMX is not supported on current hardware. Code sample cannot be run.\" is printed above, the hardware being used does not support Intel® AMX. Therefore, this code sample cannot proceed."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ccd66ee-aac5-4a60-8f66-417612d4d3af",
   "metadata": {},
   "source": [
    "## Running Inference\n",
    "The function `runInference()` performs inference with the selected model and precision. The environment variable `ONEDNN_MAX_CPU_ISA` is used to enable or disable Intel® AMX. **Note that this environment variable is only read once.** This means that running with both Intel® AMX and Intel® VNNI requires separate processes. The best practice is to set this environment variable before running your script. For more information, refer to the [oneDNN documentation on CPU Dispatcher Control](https://www.intel.com/content/www/us/en/develop/documentation/onednn-developer-guide-and-reference/top/performance-profiling-and-inspection/cpu-dispatcher-control.html).\n",
    "\n",
    "To use BF16 in operations, run the forward pass inside the `torch.cpu.amp.autocast()` context manager. For INT8, the quantization feature from Intel® Extension for PyTorch (IPEX) is used to quantize the FP32 model to INT8 before running inference.\n",
    "\n",
    "TorchScript is also used to deploy the model in graph mode instead of imperative mode for faster runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f08d718",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"ONEDNN_MAX_CPU_ISA\"] = \"AVX512_CORE_AMX\""
   ]
  },
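  {
   "cell_type": "markdown",
   "id": "sketch-bf16-autocast-md",
   "metadata": {},
   "source": [
    "As a minimal sketch of the BF16 path described above, the next cell runs a toy `torch.nn.Linear` layer under CPU autocast. The layer and its input shape are illustrative assumptions and are not part of the sample; `runInference()` applies the same pattern to ResNet50 and BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-bf16-autocast-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: a toy layer stands in for ResNet50/BERT.\n",
    "import torch\n",
    "\n",
    "toy_model = torch.nn.Linear(16, 4).eval()   # hypothetical stand-in model\n",
    "toy_input = torch.rand(2, 16)               # hypothetical input batch\n",
    "\n",
    "# Inside the autocast context, eligible ops such as linear layers run in bfloat16.\n",
    "with torch.no_grad(), torch.cpu.amp.autocast():\n",
    "    toy_output = toy_model(toy_input)\n",
    "\n",
    "# Inspect the result shape and the dtype chosen by autocast.\n",
    "print(toy_output.shape, toy_output.dtype)"
   ]
  },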
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b8e21c9-aaa5-4f75-b00a-0d875cc0bfba",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Function to perform inference on ResNet50 and BERT\n",
    "\"\"\"\n",
    "def runInference(model, data, modelName=\"resnet50\", dataType=\"FP32\", amx=True):\n",
    "    \"\"\"\n",
    "    Input parameters\n",
    "        model: the PyTorch model object used for inference\n",
    "        data: a sample input into the model\n",
    "        modelName: str representing the name of the model, supported values - resnet50, bert\n",
    "        dataType: str representing the data type for model parameters, supported values - FP32, BF16, INT8\n",
    "        amx: selects the ISA label printed for the run case (True: AVX512_CORE_AMX, False: AVX512_CORE_VNNI); the ISA itself is set through ONEDNN_MAX_CPU_ISA, Default: True\n",
    "    Return value\n",
    "        inference_time: the time in seconds it takes to perform inference with the model\n",
    "    \"\"\"\n",
    "    \n",
    "    # Display run case\n",
    "    if amx:\n",
    "        isa_text = \"AVX512_CORE_AMX\"\n",
    "    else:\n",
    "        isa_text = \"AVX512_CORE_VNNI\"\n",
    "    print(\"%s %s inference with %s\" %(modelName, dataType, isa_text))\n",
    "\n",
    "    # Special variables for specific models\n",
    "    batch_size = None\n",
    "    if \"resnet50\" == modelName:\n",
    "        batch_size = RESNET_BATCH_SIZE\n",
    "    elif \"bert\" == modelName:\n",
    "        d = torch.randint(model.config.vocab_size, size=[BERT_BATCH_SIZE, BERT_SEQ_LENGTH]) # sample data input for torchscript and inference\n",
    "        batch_size = BERT_BATCH_SIZE\n",
    "    else:\n",
    "        raise Exception(\"ERROR: modelName %s is not supported. Choose from %s\" %(modelName, SUPPORTED_MODELS))\n",
    "\n",
    "    # Prepare model for inference based on precision (FP32, BF16, INT8)\n",
    "    if \"INT8\" == dataType:\n",
    "        # Quantize model to INT8 if needed (one time)\n",
    "        model_filename = \"quantized_model_%s.pt\" %modelName\n",
    "        if not os.path.exists(model_filename):\n",
    "            qconfig = ipex.quantization.default_static_qconfig\n",
    "            prepared_model = prepare(model, qconfig, example_inputs=data, inplace=False)\n",
    "            converted_model = convert(prepared_model)\n",
    "            with torch.no_grad():\n",
    "                if \"resnet50\" == modelName:\n",
    "                    traced_model = torch.jit.trace(converted_model, data)\n",
    "                elif \"bert\" == modelName:\n",
    "                    traced_model = torch.jit.trace(converted_model, (d,), check_trace=False, strict=False)\n",
    "                else:\n",
    "                    raise Exception(\"ERROR: modelName %s is not supported. Choose from %s\" %(modelName, SUPPORTED_MODELS))\n",
    "                traced_model = torch.jit.freeze(traced_model)\n",
    "            traced_model.save(model_filename)\n",
    "\n",
    "        # Load INT8 model for inference\n",
    "        model = torch.jit.load(model_filename)\n",
    "        model.eval()\n",
    "        model = torch.jit.freeze(model)\n",
    "    elif \"BF16\" == dataType:\n",
    "        model = ipex.optimize(model, dtype=torch.bfloat16)\n",
    "        with torch.no_grad():\n",
    "            with torch.cpu.amp.autocast():\n",
    "                if \"resnet50\" == modelName:\n",
    "                    model = torch.jit.trace(model, data)\n",
    "                elif \"bert\" == modelName:\n",
    "                    model = torch.jit.trace(model, (d,), check_trace=False, strict=False)\n",
    "                else:\n",
    "                    raise Exception(\"ERROR: modelName %s is not supported. Choose from %s\" %(modelName, SUPPORTED_MODELS))\n",
    "                model = torch.jit.freeze(model)\n",
    "    else: # FP32\n",
    "        with torch.no_grad():\n",
    "            if \"resnet50\" == modelName:\n",
    "                model = torch.jit.trace(model, data)\n",
    "            elif \"bert\" == modelName:\n",
    "                model = torch.jit.trace(model, (d,), check_trace=False, strict=False)\n",
    "            else:\n",
    "                raise Exception(\"ERROR: modelName %s is not supported. Choose from %s\" %(modelName, SUPPORTED_MODELS))\n",
    "            model = torch.jit.freeze(model)\n",
    "\n",
    "    # Run inference\n",
    "    with torch.no_grad():\n",
    "        if \"BF16\" == dataType:\n",
    "            with torch.cpu.amp.autocast():\n",
    "                # Warm up\n",
    "                for i in range(5):\n",
    "                    model(data)\n",
    "                \n",
    "                # Measure latency\n",
    "                start_time = time()\n",
    "                model(data)\n",
    "                end_time = time()\n",
    "        else:\n",
    "            # Warm up\n",
    "            for i in range(5):\n",
    "                model(data)\n",
    "            \n",
    "            # Measure latency\n",
    "            start_time = time()\n",
    "            model(data)\n",
    "            end_time = time()\n",
    "    inference_time = end_time - start_time\n",
    "    print(\"Inference on batch size %d took %.3f seconds\" %(batch_size, inference_time))\n",
    "\n",
    "    return inference_time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dad2dae",
   "metadata": {},
   "source": [
    "The function summarizeResults() displays the inference times and generates one graph for comparing the inference times and another graph for comparing the speedup using FP32 as the baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cf736a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Prints out results and displays figures summarizing output.\n",
    "\"\"\"\n",
    "def summarizeResults(modelName=\"\", results=None, batch_size=1):\n",
    "    \"\"\"\n",
    "    Input parameters\n",
    "        modelName: a str representing the name of the model\n",
    "        results: a dict with the run case and its corresponding time in seconds\n",
    "        batch_size: an integer for the batch size\n",
    "    Return value\n",
    "        None\n",
    "    \"\"\"\n",
    "\n",
    "    # Inference time results\n",
    "    print(\"\\nSummary for %s (Batch Size = %d)\" %(modelName, batch_size))\n",
    "    for key in results.keys():\n",
    "        print(\"%s inference time: %.3f seconds\" %(key, results[key]))\n",
    "\n",
    "    # Create bar chart with inference time results\n",
    "    plt.figure()\n",
    "    plt.title(\"%s Inference Time (Batch Size = %d)\" %(modelName, batch_size))\n",
    "    plt.xlabel(\"Run Case\")\n",
    "    plt.ylabel(\"Inference Time (seconds)\")\n",
    "    plt.bar(results.keys(), results.values())\n",
    "\n",
    "    # Calculate speedup when using Intel® AMX\n",
    "    print(\"\\n\")\n",
    "    bf16_with_amx_speedup = results[\"FP32\"] / results[\"BF16_with_AMX\"]\n",
    "    print(\"BF16 with Intel® AMX is %.2fX faster than FP32\" %bf16_with_amx_speedup)\n",
    "    int8_with_vnni_speedup = results[\"FP32\"] / results[\"INT8_with_VNNI\"]\n",
    "    print(\"INT8 without Intel® AMX is %.2fX faster than FP32\" %int8_with_vnni_speedup)\n",
    "    int8_with_amx_speedup = results[\"FP32\"] / results[\"INT8_with_AMX\"]\n",
    "    print(\"INT8 with Intel® AMX is %.2fX faster than FP32\" %int8_with_amx_speedup)\n",
    "    print(\"\\n\\n\")\n",
    "\n",
    "    # Create bar chart with speedup results\n",
    "    plt.figure()\n",
    "    plt.title(\"%s Intel® AMX BF16/INT8 Speedup over FP32\" %modelName)\n",
    "    plt.xlabel(\"Run Case\")\n",
    "    plt.ylabel(\"Speedup\")\n",
    "    plt.bar(results.keys(), \n",
    "        [1, bf16_with_amx_speedup, int8_with_vnni_speedup, int8_with_amx_speedup]\n",
    "    )"
   ]
  },
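  {
   "cell_type": "markdown",
   "id": "sketch-speedup-md",
   "metadata": {},
   "source": [
    "The speedup that `summarizeResults()` reports is the FP32 time divided by the time of each run case. The next cell is a minimal sketch of that arithmetic; the timings in it are placeholders chosen for illustration, not measurements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-speedup-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder timings in seconds; these are NOT measured results.\n",
    "example_times = {\"FP32\": 1.00, \"BF16_with_AMX\": 0.25}\n",
    "\n",
    "# Speedup = baseline time / run-case time, so values above 1 mean faster than FP32.\n",
    "example_speedup = example_times[\"FP32\"] / example_times[\"BF16_with_AMX\"]\n",
    "print(\"%.2fX\" % example_speedup)"
   ]
  },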
  {
   "cell_type": "markdown",
   "id": "9e42672a",
   "metadata": {},
   "source": [
    "### VNNI: ResNet50 and BERT\n",
    "Since `ONEDNN_MAX_CPU_ISA` is read only once when a workload runs, another process must be used to run with a different setting. \n",
    "In other words, changing `ONEDNN_MAX_CPU_ISA` at runtime in the same process has no effect.\n",
    "Thus, to run with VNNI, a separate script is run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "555ec5a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python python/pytorch_inference_vnni.py"
   ]
  },
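  {
   "cell_type": "markdown",
   "id": "sketch-vnni-subprocess-md",
   "metadata": {},
   "source": [
    "As a hedged alternative to the shell command above, the separate process can also be launched from Python with the ISA cap passed explicitly through the child's environment. The sketch assumes the script path from the previous cell; whether the script sets `ONEDNN_MAX_CPU_ISA` on its own is not shown in this notebook, so the explicit override is an assumption. It is shown as text so that running all cells does not execute the benchmark twice.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "# Copy the current environment and cap the ISA for the child process only.\n",
    "child_env = dict(os.environ, ONEDNN_MAX_CPU_ISA=\"AVX512_CORE_VNNI\")\n",
    "\n",
    "# Run the VNNI script in its own process so oneDNN reads the new setting.\n",
    "completed = subprocess.run(\n",
    "    [sys.executable, \"python/pytorch_inference_vnni.py\"],\n",
    "    env=child_env,\n",
    "    capture_output=True,\n",
    "    text=True,\n",
    ")\n",
    "print(completed.stdout)\n",
    "```\n",
    "\n",
    "The parent notebook process keeps its own `ONEDNN_MAX_CPU_ISA` value because only the copied environment is modified."
   ]
  },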
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d194fa7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the inference times for INT8 using AVX-512 VNNI\n",
    "int8_with_vnni_resnet_inference_time = 0.033   # TODO: replace with the ResNet50 time printed by the script above\n",
    "int8_with_vnni_bert_inference_time = 0.691    # TODO: replace with the BERT time printed by the script above"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c61288e7",
   "metadata": {},
   "source": [
    "### Intel® AMX: ResNet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a6a84c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up ResNet50 model and sample data\n",
    "resnet_model = models.resnet50(pretrained=True)\n",
    "resnet_data = torch.rand(RESNET_BATCH_SIZE, 3, 224, 224)\n",
    "resnet_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b26789b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FP32 (baseline)\n",
    "fp32_resnet_inference_time = runInference(resnet_model, resnet_data, modelName=\"resnet50\", dataType=\"FP32\", amx=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad0c512",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BF16 using Intel® AMX\n",
    "bf16_amx_resnet_inference_time = runInference(resnet_model, resnet_data, modelName=\"resnet50\", dataType=\"BF16\", amx=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd9f1bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# INT8 using Intel® AMX\n",
    "int8_amx_resnet_inference_time = runInference(resnet_model, resnet_data, modelName=\"resnet50\", dataType=\"INT8\", amx=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59fcbbe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize and display results\n",
    "results_resnet = {\n",
    "        \"FP32\": fp32_resnet_inference_time,\n",
    "        \"BF16_with_AMX\": bf16_amx_resnet_inference_time,\n",
    "        \"INT8_with_VNNI\": int8_with_vnni_resnet_inference_time,\n",
    "        \"INT8_with_AMX\": int8_amx_resnet_inference_time\n",
    "    }\n",
    "summarizeResults(\"ResNet50\", results_resnet, RESNET_BATCH_SIZE)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75a62b72",
   "metadata": {},
   "source": [
    "The first graph displays the inference time of each run case for one batch. In general, the times should decrease from left to right because lower precision and Intel® AMX accelerate the computations. The second graph displays the speedup of each run case relative to FP32. In general, the speedup should increase from left to right."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b36fa4b3",
   "metadata": {},
   "source": [
    "### Intel® AMX: BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f173e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up BERT model and sample data\n",
    "bert_model = BertModel.from_pretrained(\"bert-base-uncased\")\n",
    "bert_data = torch.randint(bert_model.config.vocab_size, size=[BERT_BATCH_SIZE, BERT_SEQ_LENGTH])\n",
    "bert_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a5847c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FP32 (baseline)\n",
    "fp32_bert_inference_time = runInference(bert_model, bert_data, modelName=\"bert\", dataType=\"FP32\", amx=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35fc58e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BF16 using Intel® AMX\n",
    "bf16_amx_bert_inference_time = runInference(bert_model, bert_data, modelName=\"bert\", dataType=\"BF16\", amx=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b3d2ccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# INT8 using Intel® AMX\n",
    "int8_amx_bert_inference_time = runInference(bert_model, bert_data, modelName=\"bert\", dataType=\"INT8\", amx=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3721e698",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize and display results\n",
    "results_bert = {\n",
    "        \"FP32\": fp32_bert_inference_time,\n",
    "        \"BF16_with_AMX\": bf16_amx_bert_inference_time,\n",
    "        \"INT8_with_VNNI\": int8_with_vnni_bert_inference_time,\n",
    "        \"INT8_with_AMX\": int8_amx_bert_inference_time\n",
    "    }\n",
    "summarizeResults(\"BERT\", results_bert, BERT_BATCH_SIZE)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03e63f93",
   "metadata": {},
   "source": [
    "The first graph displays the inference time of each run case for one batch. In general, the times should decrease from left to right because lower precision and Intel® AMX accelerate the computations. The second graph displays the speedup of each run case relative to FP32. In general, the speedup should increase from left to right."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b559aeb8",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0da073a6",
   "metadata": {},
   "source": [
    "This code sample shows how to enable and disable Intel® AMX with the `ONEDNN_MAX_CPU_ISA` environment variable, as well as the performance improvements from using Intel® AMX with BF16 and INT8 for inference on the ResNet50 and BERT models. Performance will vary based on your hardware and software versions. To see a larger performance gap between VNNI and Intel® AMX, increase the batch size. For even more speedup, consider using the Intel® Extension for PyTorch (IPEX) [Launch Script](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/launch_script.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa0877d6-e045-4091-b5e4-4dfcb6d04f7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('[CODE_SAMPLE_COMPLETED_SUCCESSFULLY]')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "vscode": {
   "interpreter": {
    "hash": "ed6ae0d06e7bec0fef5f1fb38f177ceea45508ce95c68ed2f49461dd6a888a39"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
